#!/usr/bin/env python3
"""
Phase 5: Reviewer #2 Response — Robustness & Validation
Five parts addressing reviewer concerns:
  A. Income balance validation (WDI BOP primary/secondary income)
  B. FH long-difference robustness (5yr, 10yr non-overlapping averages)
  C. FH income group robustness (GDP/capita control + tercile splits)
  D. KAOPEN interaction with ER regime & trade openness
  E. NFA split detail (creditor/debtor on income balance)

Output: extensions/output/tables/phase5_*.md (5 files)
"""

import pandas as pd
import numpy as np
from pathlib import Path
import sys
import requests
import time
import io

# ── Paths ──────────────────────────────────────────────────────────────────
# PROJECT_DIR is the grandparent of this file's resolved location; ROOT_DIR is
# one level above that and is where the sibling "multilateral" and "trilemma"
# directories are looked up.
PROJECT_DIR = Path(__file__).resolve().parent.parent
ROOT_DIR = PROJECT_DIR.parent
MULTILATERAL_DIR = ROOT_DIR / "multilateral"
TRILEMMA_DIR = ROOT_DIR / "trilemma"

# Destination for the phase5_*.md tables; created at import time if missing.
OUT_TABLES = PROJECT_DIR / "output" / "tables"
OUT_TABLES.mkdir(parents=True, exist_ok=True)
# Location of extensions_panel.csv and of the cached WDI downloads.
# Unlike OUT_TABLES, this directory is not created here.
OUT_DATA = PROJECT_DIR / "data" / "processed"

# The sys.path entry must be inserted before the import below; `model` is
# expected to resolve to a module under multilateral/src (verify if the
# import picks up a different `model` on your path).
sys.path.insert(0, str(MULTILATERAL_DIR / "src"))
from model import PanelGLS


def load_panel():
    """Read extensions_panel.csv from OUT_DATA, keeping rows with year <= 2024.

    Returns an independent copy so callers can add columns freely.
    """
    panel = pd.read_csv(OUT_DATA / "extensions_panel.csv")
    in_sample = panel["year"] <= 2024
    return panel.loc[in_sample].copy()


def stars(p):
    """Map a p-value to its significance marker.

    Returns "***" for p < 0.01, "**" for p < 0.05, "*" for p < 0.1 and an
    empty string otherwise (including NaN, which fails every comparison).
    """
    # Cutoffs are checked from strictest to loosest; the first hit wins.
    for cutoff, marker in ((0.01, "***"), (0.05, "**"), (0.1, "*")):
        if p < cutoff:
            return marker
    return ""


def run_model(df, dep_var, regressors, label, min_obs=100):
    """Fit PanelGLS on the complete-case sample and return a results frame.

    Args:
        df: Panel with "iso3" and "year" columns plus the model variables.
        dep_var: Name of the dependent-variable column.
        regressors: List of regressor column names.
        label: Model label used in console output and in the "model" column.
        min_obs: Minimum number of complete-case rows required to estimate.

    Returns:
        The frame from ``model.to_dataframe(feature_names=regressors)`` with
        the columns "model", "dep_var", "n_obs", "n_countries", "r_squared"
        and "rho" added, or ``None`` when the model is skipped (a required
        column is absent, or fewer than ``min_obs`` rows survive the dropna).
    """
    regressors = list(regressors)
    cols = [dep_var] + regressors

    # Callers already treat None as "skipped", so an absent column is reported
    # the same way as a too-small sample instead of raising KeyError.
    missing = [c for c in cols + ["iso3", "year"] if c not in df.columns]
    if missing:
        print(f"  SKIP {label}: missing columns {missing}")
        return None

    sub = df.dropna(subset=cols).copy()
    if len(sub) < min_obs:
        print(f"  SKIP {label}: only {len(sub)} obs")
        return None

    model = PanelGLS()
    model.fit(sub[dep_var].values, sub[regressors].values,
              sub["iso3"].values, sub["year"].values)
    print(f"\n  {label}")
    model.summary(feature_names=regressors)

    res = model.to_dataframe(feature_names=regressors)
    res["model"] = label
    res["dep_var"] = dep_var
    res["n_obs"] = model.n_obs
    res["n_countries"] = model.n_countries
    res["r_squared"] = model.r_squared
    res["rho"] = model.rho
    return res


# ═══════════════════════════════════════════════════════════════════════════
# Part A: Income Balance Validation
# ═══════════════════════════════════════════════════════════════════════════

def download_wdi_indicator(indicator, var_name, force=False):
    """Download a WDI indicator from World Bank API.

    Args:
        indicator: WDI indicator code placed in the request URL.
        var_name: Column name for the values; also used in the cache filename.
        force: When True, ignore an existing cache file and download again.

    Returns:
        DataFrame with columns "iso3", "year" and ``var_name``. Observations
        with a null value or an empty country ISO3 code are dropped. The frame
        is empty (but still has those three columns) when the API returns no
        usable observations.

    Raises:
        requests.HTTPError: If any page request returns an error status.
    """
    cache_path = OUT_DATA / f"wdi_{var_name}_raw.csv"
    if cache_path.exists() and not force:
        print(f"  Using cached {var_name}: {cache_path}")
        return pd.read_csv(cache_path)

    print(f"  Downloading WDI {indicator} ({var_name})...")
    base_url = "https://api.worldbank.org/v2/country/all/indicator"
    all_rows = []
    page = 1

    while True:
        url = (f"{base_url}/{indicator}?format=json&per_page=1000"
               f"&date=1970:2024&page={page}")
        resp = requests.get(url, timeout=60)
        resp.raise_for_status()
        data = resp.json()

        # The payload is read as [metadata, observations]; anything shorter,
        # or a null observation list, ends the download.
        if len(data) < 2 or data[1] is None:
            break

        for obs in data[1]:
            if obs["value"] is not None and obs.get("countryiso3code"):
                all_rows.append({
                    "iso3": obs["countryiso3code"],
                    "year": int(obs["date"]),
                    var_name: float(obs["value"]),
                })

        total_pages = data[0]["pages"]
        print(f"    Page {page}/{total_pages} ({len(all_rows)} obs)")
        if page >= total_pages:
            break
        page += 1
        time.sleep(0.5)  # pause between page requests

    # Fixing the columns keeps the schema stable even with zero rows.
    df = pd.DataFrame(all_rows, columns=["iso3", "year", var_name])
    if df.empty:
        # Do not cache an empty result: a header-only or empty CSV would be
        # picked up on the next run and silently stand in for real data.
        print(f"  WARNING: no observations returned for {indicator}; "
              f"cache not written")
        return df.astype({"year": "int64", var_name: "float64"})

    cache_path.parent.mkdir(parents=True, exist_ok=True)
    df.to_csv(cache_path, index=False)
    print(f"  Downloaded {len(df):,} obs for {df['iso3'].nunique()} countries")
    return df


def part_a_income_validation(df):
    """Validate income_balance_gdp residual against official BOP data.

    Downloads three WDI series (BN.GSR.FCTY.CD, BN.TRF.CURR.CD and
    NY.GDP.MKTP.CD), expresses the two income series in percent of GDP,
    correlates them with the panel's ``income_balance_gdp`` and re-runs the
    Z → income regression with each income measure as dependent variable.
    Writes ``phase5_income_validation.md`` to OUT_TABLES.

    Args:
        df: Panel with "iso3", "year", "income_balance_gdp", Z_1..Z_3 and,
            where available, the EBA control columns.

    Returns:
        Tuple ``(corr_total, n_both, n_countries)``: the correlation between
        the residual and the official total, and the size of the sample on
        which it was computed.
    """
    print("\n" + "=" * 70)
    print("PART A: Income Balance Validation")
    print("=" * 70)

    # Download primary income (net), secondary income (net), and GDP
    primary = download_wdi_indicator("BN.GSR.FCTY.CD", "primary_income_usd")
    secondary = download_wdi_indicator("BN.TRF.CURR.CD", "secondary_income_usd")
    gdp_usd = download_wdi_indicator("NY.GDP.MKTP.CD", "gdp_current_usd")

    # Merge and compute ratios (percent of GDP)
    bop = primary.merge(secondary, on=["iso3", "year"], how="outer")
    bop = bop.merge(gdp_usd, on=["iso3", "year"], how="left")
    bop = bop.dropna(subset=["gdp_current_usd"])
    bop["primary_income_gdp"] = bop["primary_income_usd"] / bop["gdp_current_usd"] * 100
    bop["secondary_income_gdp"] = bop["secondary_income_usd"] / bop["gdp_current_usd"] * 100
    # A component that is missing counts as zero, but the total stays NaN when
    # both components are missing rather than becoming a fabricated 0.
    bop["bop_income_gdp"] = (bop[["primary_income_gdp", "secondary_income_gdp"]]
                             .sum(axis=1, min_count=1))

    # Merge with our panel
    merged = df.merge(bop[["iso3", "year", "primary_income_gdp",
                           "secondary_income_gdp", "bop_income_gdp"]],
                      on=["iso3", "year"], how="left")

    # Correlation analysis on rows where both measures are available
    both = merged.dropna(subset=["income_balance_gdp", "bop_income_gdp"])
    n_both = len(both)
    n_countries = both["iso3"].nunique()

    corr_total = both["income_balance_gdp"].corr(both["bop_income_gdp"])
    # Component correlations treat a missing component as zero so that they
    # are computed on the same rows as the total.
    corr_primary = both["income_balance_gdp"].corr(
        both["primary_income_gdp"].fillna(0))
    corr_secondary = both["income_balance_gdp"].corr(
        both["secondary_income_gdp"].fillna(0))

    print(f"\n  Validation sample: N={n_both:,}, {n_countries} countries")
    print(f"  Corr(residual, BOP total): {corr_total:.3f}")
    print(f"  Corr(residual, primary):   {corr_primary:.3f}")
    print(f"  Corr(residual, secondary): {corr_secondary:.3f}")

    # Re-run M3 (Z → income_balance) using official BOP income
    demo_vars = ["Z_1", "Z_2", "Z_3"]
    eba_controls = ["fiscal_bal_gdp", "nfa_gdp_lag", "log_rel_opw", "kaopen"]
    # Keep only controls that exist and have more than 200 non-missing values
    eba_controls = [c for c in eba_controls if c in merged.columns
                    and merged[c].notna().sum() > 200]
    base_vars = demo_vars + eba_controls

    # Same specification, four dependent variables: the panel residual, the
    # official total, and the two official components.
    specs = [
        ("income_balance_gdp", "M3-orig: Z → income_balance (residual)"),
        ("bop_income_gdp", "M3-bop: Z → BOP income (official)"),
        ("primary_income_gdp", "M3-primary: Z → primary income"),
        ("secondary_income_gdp", "M3-secondary: Z → secondary income"),
    ]
    results = []
    for dep_var, label in specs:
        r = run_model(merged, dep_var, base_vars, label)
        if r is not None:
            results.append(r)

    # Write output
    lines = ["# Part A: Income Balance Validation", ""]
    lines.append("## Correlation Between Residual and Official BOP Income")
    lines.append("")
    lines.append(f"- Validation sample: N = {n_both:,}, {n_countries} countries")
    lines.append(f"- Corr(residual income_balance, BOP total income): **{corr_total:.3f}**")
    lines.append(f"- Corr(residual, primary income only): {corr_primary:.3f}")
    lines.append(f"- Corr(residual, secondary income only): {corr_secondary:.3f}")
    lines.append("")

    if results:
        rdf = pd.concat(results, ignore_index=True)
        lines.append("## Regression Comparison: Z → Income Balance")
        lines.append("")
        lines.append("| Model | Dep Var | N | Countries | R² | ρ |")
        lines.append("|-------|---------|---|-----------|----|----|")
        for m in rdf["model"].unique():
            # Model-level statistics repeat on every row; take the first.
            row = rdf[rdf["model"] == m].iloc[0]
            lines.append(f"| {m} | {row['dep_var']} | {row['n_obs']:,} | "
                         f"{row['n_countries']} | {row['r_squared']:.3f} | "
                         f"{row['rho']:.3f} |")

        lines.append("")
        lines.append("## Z Coefficients Across Income Measures")
        lines.append("")
        lines.append("| Variable | Model | Coef | SE | p-value |")
        lines.append("|----------|-------|------|----|---------|")
        for v in ["Z_1", "Z_2", "Z_3"]:
            sub = rdf[rdf["variable"] == v]
            for _, row in sub.iterrows():
                if pd.isna(row["p_value"]):
                    continue
                sig = stars(row["p_value"])
                lines.append(f"| {v} | {row['model']} | "
                             f"{row['coefficient']:.4f}{sig} | "
                             f"{row['std_error']:.4f} | {row['p_value']:.4f} |")

    md_path = OUT_TABLES / "phase5_income_validation.md"
    # The report contains non-ASCII characters (→, ², ρ), so the encoding is
    # set explicitly instead of relying on the platform default.
    md_path.write_text("\n".join(lines), encoding="utf-8")
    print(f"\n  Output: {md_path}")
    return corr_total, n_both, n_countries


# ═══════════════════════════════════════════════════════════════════════════
# Part B: FH Long-Difference Robustness
# ═══════════════════════════════════════════════════════════════════════════

def part_b_fh_long_diff(df):
    """FH regressions on 5-year and 10-year non-overlapping averages.

    For each horizon (annual, 5-year, 10-year) runs two regressions of
    ``investment_gdp``: on ``savings_gdp`` alone (M5a), and on savings, Z_1..Z_3
    and the savings × Z interactions (M5c). Writes ``phase5_fh_long_diff.md``
    to OUT_TABLES.

    Args:
        df: Panel with "iso3", "year", "savings_gdp", "investment_gdp" and
            Z_1..Z_3. The caller's frame is not modified.
    """
    print("\n" + "=" * 70)
    print("PART B: FH Long-Difference Robustness")
    print("=" * 70)

    demo_vars = ["Z_1", "Z_2", "Z_3"]
    interaction_vars = ["savings_x_Z1", "savings_x_Z2", "savings_x_Z3"]

    # Create interaction vars fresh, on a private copy
    df = df.copy()
    for inter, z in zip(interaction_vars, demo_vars):
        df[inter] = df["savings_gdp"] * df[z]

    all_results = []

    for horizon, label in [(1, "Annual"), (5, "5yr avg"), (10, "10yr avg")]:
        print(f"\n  --- {label} ---")

        if horizon == 1:
            # Nothing downstream mutates the frame, so no extra copy is needed.
            hdf = df
        else:
            # Non-overlapping averages: periods are counted from the first
            # year in the panel, and the interaction terms are averaged as
            # already-formed products.
            df["period"] = ((df["year"] - df["year"].min()) // horizon).astype(int)
            avg_cols = (["savings_gdp", "investment_gdp"] + demo_vars
                        + interaction_vars)
            avg_cols = [c for c in avg_cols if c in df.columns]
            hdf = (df.groupby(["iso3", "period"])[avg_cols]
                   .mean().reset_index())
            # Use period index as "year" so consecutive periods are adjacent
            hdf["year"] = hdf["period"]

        # M5a: Bare FH
        r = run_model(hdf, "investment_gdp", ["savings_gdp"],
                      f"M5a-{label}: FH baseline")
        if r is not None:
            all_results.append(r)

        # M5c: FH + S×Z interactions
        fh_int_vars = ["savings_gdp"] + demo_vars + interaction_vars
        fh_int_vars = [v for v in fh_int_vars if v in hdf.columns]
        r = run_model(hdf, "investment_gdp", fh_int_vars,
                      f"M5c-{label}: FH + S×Z")
        if r is not None:
            all_results.append(r)

    # Write output
    lines = ["# Part B: FH Long-Difference Robustness", ""]
    lines.append("Tests whether the S×Z interaction survives in long-difference "
                 "specifications to address persistence/nonstationarity concerns.")
    lines.append("")

    if all_results:
        rdf = pd.concat(all_results, ignore_index=True)

        lines.append("## Model Comparison")
        lines.append("")
        lines.append("| Model | N | Countries | R² | ρ |")
        lines.append("|-------|---|-----------|----|----|")
        for m in rdf["model"].unique():
            # Model-level statistics repeat on every row; take the first.
            row = rdf[rdf["model"] == m].iloc[0]
            lines.append(f"| {m} | {row['n_obs']:,} | "
                         f"{row['n_countries']} | {row['r_squared']:.3f} | "
                         f"{row['rho']:.3f} |")

        lines.append("")
        lines.append("## Key Coefficients")
        lines.append("")
        lines.append("| Variable | Model | Coef | SE | p-value |")
        lines.append("|----------|-------|------|----|---------|")
        for v in ["savings_gdp", "savings_x_Z1", "savings_x_Z2", "savings_x_Z3"]:
            sub = rdf[rdf["variable"] == v]
            for _, row in sub.iterrows():
                if pd.isna(row["p_value"]):
                    continue
                sig = stars(row["p_value"])
                lines.append(f"| {v} | {row['model']} | "
                             f"{row['coefficient']:.4f}{sig} | "
                             f"{row['std_error']:.4f} | {row['p_value']:.4f} |")

    md_path = OUT_TABLES / "phase5_fh_long_diff.md"
    # The report contains non-ASCII characters (×, ², ρ), so the encoding is
    # set explicitly instead of relying on the platform default.
    md_path.write_text("\n".join(lines), encoding="utf-8")
    print(f"\n  Output: {md_path}")


# ═══════════════════════════════════════════════════════════════════════════
# Part C: FH Income Group Robustness
# ═══════════════════════════════════════════════════════════════════════════

def part_c_fh_income_robustness(df):
    """FH regressions with income controls and by income tercile.

    Runs the savings × Z specification of ``investment_gdp`` (M5c) on the full
    sample, again with log GDP per capita and real GDP growth added where
    those columns are sufficiently populated, and separately within each
    within-year tercile of ``gdp_pc_ppp``. Writes
    ``phase5_fh_income_robustness.md`` to OUT_TABLES.

    Args:
        df: Panel with "iso3", "year", "savings_gdp", "investment_gdp",
            Z_1..Z_3 and optionally "gdp_pc_ppp" and "rgdp_growth". The
            caller's frame is not modified.
    """
    print("\n" + "=" * 70)
    print("PART C: FH Income Group Robustness")
    print("=" * 70)

    df = df.copy()
    demo_vars = ["Z_1", "Z_2", "Z_3"]
    df["savings_x_Z1"] = df["savings_gdp"] * df["Z_1"]
    df["savings_x_Z2"] = df["savings_gdp"] * df["Z_2"]
    df["savings_x_Z3"] = df["savings_gdp"] * df["Z_3"]

    # gdp_pc_ppp is optional: without it both the income control and the
    # tercile split are skipped instead of raising KeyError up front.
    has_gdp_pc = "gdp_pc_ppp" in df.columns
    if has_gdp_pc:
        # Values below 100 are floored at 100 before taking logs.
        df["log_gdp_pc"] = np.log(df["gdp_pc_ppp"].clip(lower=100))

    fh_int_vars = (["savings_gdp"] + demo_vars
                   + ["savings_x_Z1", "savings_x_Z2", "savings_x_Z3"])

    all_results = []

    # M5c baseline (for comparison)
    r = run_model(df, "investment_gdp", fh_int_vars,
                  "M5c: FH + S×Z (baseline)")
    if r is not None:
        all_results.append(r)

    # M5c + log(GDP/capita) + growth; each control needs > 500 non-missing obs
    income_controls = []
    if "log_gdp_pc" in df.columns and df["log_gdp_pc"].notna().sum() > 500:
        income_controls.append("log_gdp_pc")
    if "rgdp_growth" in df.columns and df["rgdp_growth"].notna().sum() > 500:
        income_controls.append("rgdp_growth")

    if income_controls:
        r = run_model(df, "investment_gdp", fh_int_vars + income_controls,
                      "M5c+income: FH + S×Z + income controls")
        if r is not None:
            all_results.append(r)

    # Split by income tercile, with cutoffs recomputed within each year
    if has_gdp_pc:
        tercile_bounds = df.groupby("year")["gdp_pc_ppp"].quantile([1/3, 2/3])
        tercile_dict = tercile_bounds.unstack()
        df["income_tercile"] = np.nan
        for year in df["year"].unique():
            if year not in tercile_dict.index:
                continue
            mask = df["year"] == year
            t1 = tercile_dict.loc[year, 1/3]
            t2 = tercile_dict.loc[year, 2/3]
            # Rows with missing gdp_pc_ppp fail every comparison and keep NaN.
            df.loc[mask & (df["gdp_pc_ppp"] <= t1), "income_tercile"] = 1
            df.loc[mask & (df["gdp_pc_ppp"] > t1) & (df["gdp_pc_ppp"] <= t2),
                   "income_tercile"] = 2
            df.loc[mask & (df["gdp_pc_ppp"] > t2), "income_tercile"] = 3

        for terc, label in [(1, "low-income"), (2, "middle-income"),
                            (3, "high-income")]:
            sub = df[df["income_tercile"] == terc].copy()
            r = run_model(sub, "investment_gdp", fh_int_vars,
                          f"M5c-{label}: FH + S×Z")
            if r is not None:
                all_results.append(r)

    # Write output
    lines = ["# Part C: FH Income Group Robustness", ""]
    lines.append("Tests whether S×Z heterogeneity is a functional-form "
                 "artifact of development level.")
    lines.append("")

    if all_results:
        rdf = pd.concat(all_results, ignore_index=True)

        lines.append("## Model Comparison")
        lines.append("")
        lines.append("| Model | N | Countries | R² | ρ |")
        lines.append("|-------|---|-----------|----|----|")
        for m in rdf["model"].unique():
            # Model-level statistics repeat on every row; take the first.
            row = rdf[rdf["model"] == m].iloc[0]
            lines.append(f"| {m} | {row['n_obs']:,} | "
                         f"{row['n_countries']} | {row['r_squared']:.3f} | "
                         f"{row['rho']:.3f} |")

        lines.append("")
        lines.append("## Key Coefficients")
        lines.append("")
        lines.append("| Variable | Model | Coef | SE | p-value |")
        lines.append("|----------|-------|------|----|---------|")
        for v in ["savings_gdp", "savings_x_Z1", "savings_x_Z2", "savings_x_Z3"]:
            sub = rdf[rdf["variable"] == v]
            for _, row in sub.iterrows():
                if pd.isna(row["p_value"]):
                    continue
                sig = stars(row["p_value"])
                lines.append(f"| {v} | {row['model']} | "
                             f"{row['coefficient']:.4f}{sig} | "
                             f"{row['std_error']:.4f} | {row['p_value']:.4f} |")

    md_path = OUT_TABLES / "phase5_fh_income_robustness.md"
    # The report contains non-ASCII characters (×, ², ρ), so the encoding is
    # set explicitly instead of relying on the platform default.
    md_path.write_text("\n".join(lines), encoding="utf-8")
    print(f"\n  Output: {md_path}")


# ═══════════════════════════════════════════════════════════════════════════
# Part D: KAOPEN Interaction with ER Regime
# ═══════════════════════════════════════════════════════════════════════════

def part_d_kaopen_robustness(df):
    """Re-run KAOPEN interactions adding trade_openness and ER regime.

    Estimates the Z × KAOPEN specification for ``trade_balance_gdp`` and
    ``income_balance_gdp``, then repeats both with Z × trade_openness terms
    and with Z × ers_index terms added. ``ers_index`` is merged in from
    trilemma_panel.csv when that file exists. Writes
    ``phase5_kaopen_robustness.md`` to OUT_TABLES.

    Args:
        df: Panel with "iso3", "year", Z_1..Z_3, the two balance columns and,
            where available, the EBA controls, the Z_i_x_kaopen interactions
            and "trade_openness". The caller's frame is not modified.
    """
    print("\n" + "=" * 70)
    print("PART D: KAOPEN Interaction — ER Regime & Trade Openness")
    print("=" * 70)

    df = df.copy()

    # Merge trilemma data for ers_index
    tri_path = TRILEMMA_DIR / "data" / "processed" / "trilemma_panel.csv"
    if tri_path.exists():
        tri = pd.read_csv(tri_path, usecols=["iso3", "year", "ers_index"])
        df = df.merge(tri, on=["iso3", "year"], how="left")
        n_ers = df["ers_index"].notna().sum()
        print(f"  Merged ers_index: {n_ers:,} obs")
    else:
        print("  WARNING: trilemma_panel.csv not found, skipping ER regime")
        # An all-NaN column fails the coverage filter below, which drops the
        # ER-regime models.
        df["ers_index"] = np.nan

    demo_vars = ["Z_1", "Z_2", "Z_3"]
    eba_controls = ["fiscal_bal_gdp", "nfa_gdp_lag", "log_rel_opw", "kaopen"]
    eba_controls = [c for c in eba_controls if c in df.columns
                    and df[c].notna().sum() > 200]
    interaction_vars = ["Z_1_x_kaopen", "Z_2_x_kaopen", "Z_3_x_kaopen"]
    interaction_vars = [c for c in interaction_vars if c in df.columns
                        and df[c].notna().sum() > 200]

    base_vars = demo_vars + eba_controls + interaction_vars

    # Create additional interactions. trade_openness is optional: when it is
    # absent its interactions are not created, and the filter further down
    # then drops the trade-openness models instead of a KeyError being raised.
    if "trade_openness" in df.columns:
        for z in demo_vars:
            df[f"{z}_x_trade_open"] = df[z] * df["trade_openness"]
    for z in demo_vars:
        df[f"{z}_x_ers"] = df[z] * df["ers_index"]

    all_results = []
    dep_vars = [("trade_balance_gdp", "M4"), ("income_balance_gdp", "M5")]

    # Baseline: Z×KAOPEN on trade balance (M4) and income balance (M5)
    base_labels = {"M4": "M4-base: Z×KAOPEN → Trade Bal",
                   "M5": "M5-base: Z×KAOPEN → Income Bal"}
    for dep_var, tag in dep_vars:
        r = run_model(df, dep_var, base_vars, base_labels[tag])
        if r is not None:
            all_results.append(r)

    # + trade_openness interactions
    trade_int = ["trade_openness", "Z_1_x_trade_open",
                 "Z_2_x_trade_open", "Z_3_x_trade_open"]
    trade_int = [v for v in trade_int if v in df.columns
                 and df[v].notna().sum() > 200]

    if trade_int:
        extended_vars = base_vars + trade_int
        trade_labels = {"M4": "M4+trade: Z×KAOPEN + Z×trade_open → Trade",
                        "M5": "M5+trade: Z×KAOPEN + Z×trade_open → Income"}
        for dep_var, tag in dep_vars:
            r = run_model(df, dep_var, extended_vars, trade_labels[tag])
            if r is not None:
                all_results.append(r)

    # + ER regime interactions
    ers_int = ["ers_index", "Z_1_x_ers", "Z_2_x_ers", "Z_3_x_ers"]
    ers_int = [v for v in ers_int if v in df.columns
               and df[v].notna().sum() > 200]

    if ers_int:
        ers_vars = base_vars + ers_int
        ers_labels = {"M4": "M4+ERS: Z×KAOPEN + Z×ERS → Trade",
                      "M5": "M5+ERS: Z×KAOPEN + Z×ERS → Income"}
        for dep_var, tag in dep_vars:
            r = run_model(df, dep_var, ers_vars, ers_labels[tag])
            if r is not None:
                all_results.append(r)

    # Write output
    lines = ["# Part D: KAOPEN Interaction — ER Regime & Trade Openness Robustness",
             ""]
    lines.append("Tests whether the Z×KAOPEN compositional shift between "
                 "trade and income balance is robust to controlling for "
                 "real-side channels (trade openness, ER flexibility).")
    lines.append("")

    if all_results:
        rdf = pd.concat(all_results, ignore_index=True)

        lines.append("## Model Comparison")
        lines.append("")
        lines.append("| Model | Dep Var | N | Countries | R² | ρ |")
        lines.append("|-------|---------|---|-----------|----|----|")
        for m in rdf["model"].unique():
            # Model-level statistics repeat on every row; take the first.
            row = rdf[rdf["model"] == m].iloc[0]
            lines.append(f"| {m} | {row['dep_var']} | {row['n_obs']:,} | "
                         f"{row['n_countries']} | {row['r_squared']:.3f} | "
                         f"{row['rho']:.3f} |")

        lines.append("")
        lines.append("## Key Interaction Coefficients")
        lines.append("")
        key_vars = ["Z_1_x_kaopen", "Z_2_x_kaopen", "Z_3_x_kaopen",
                     "Z_1_x_trade_open", "Z_2_x_trade_open", "Z_3_x_trade_open",
                     "Z_1_x_ers", "Z_2_x_ers", "Z_3_x_ers",
                     "trade_openness", "ers_index"]
        lines.append("| Variable | Model | Coef | SE | p-value |")
        lines.append("|----------|-------|------|----|---------|")
        for v in key_vars:
            sub = rdf[rdf["variable"] == v]
            for _, row in sub.iterrows():
                if pd.isna(row["p_value"]):
                    continue
                sig = stars(row["p_value"])
                lines.append(f"| {v} | {row['model']} | "
                             f"{row['coefficient']:.4f}{sig} | "
                             f"{row['std_error']:.4f} | {row['p_value']:.4f} |")

    md_path = OUT_TABLES / "phase5_kaopen_robustness.md"
    # The report contains non-ASCII characters (×, →, ², ρ, —), so the
    # encoding is set explicitly instead of relying on the platform default.
    md_path.write_text("\n".join(lines), encoding="utf-8")
    print(f"\n  Output: {md_path}")


# ═══════════════════════════════════════════════════════════════════════════
# Part E: NFA Split Detail
# ═══════════════════════════════════════════════════════════════════════════

def part_e_nfa_detail(df):
    """NFA creditor/debtor detail on income balance.

    Runs the baseline Z → ``income_balance_gdp`` model, then replaces
    ``nfa_gdp_lag`` with the ``nfa_positive`` / ``nfa_negative`` pair for the
    income balance, the trade balance and ``ca_gdp``. Also reports the
    correlation between ``nfa_gdp_lag`` and ``income_balance_gdp`` separately
    for rows with positive and non-positive lagged NFA. Writes
    ``phase5_nfa_detail.md`` to OUT_TABLES.

    Args:
        df: Panel with "iso3", "year", Z_1..Z_3, the balance columns and,
            where available, the EBA controls and the NFA split columns. The
            caller's frame is not modified.
    """
    print("\n" + "=" * 70)
    print("PART E: NFA Split Detail")
    print("=" * 70)

    df = df.copy()
    demo_vars = ["Z_1", "Z_2", "Z_3"]
    eba_controls = ["fiscal_bal_gdp", "nfa_gdp_lag", "log_rel_opw", "kaopen"]
    eba_controls = [c for c in eba_controls if c in df.columns
                    and df[c].notna().sum() > 200]

    all_results = []

    # M3: Z → income_balance baseline (with nfa_gdp_lag)
    base_vars = demo_vars + eba_controls
    r = run_model(df, "income_balance_gdp", base_vars,
                  "M3: Z → Income Bal (baseline)")
    if r is not None:
        all_results.append(r)

    # M7 family: the split columns take the place of nfa_gdp_lag. They are
    # checked for presence like the other controls, so a panel without them
    # skips these models instead of raising KeyError.
    split_cols = ["nfa_positive", "nfa_negative"]
    missing_split = [c for c in split_cols if c not in df.columns]
    if missing_split:
        print(f"  WARNING: {missing_split} not in panel, "
              f"skipping NFA split models")
    else:
        creditor_vars = demo_vars + [c for c in eba_controls
                                     if c != "nfa_gdp_lag"]
        creditor_vars += split_cols
        split_specs = [
            # M7: NFA split → income_balance
            ("income_balance_gdp", "M7: NFA split → Income Bal"),
            # M7+: NFA split on trade_balance (for comparison)
            ("trade_balance_gdp", "M7-trade: NFA split → Trade Bal"),
            # M7++: NFA split on CA (overall)
            ("ca_gdp", "M7-ca: NFA split → CA"),
        ]
        for dep_var, label in split_specs:
            r = run_model(df, dep_var, creditor_vars, label)
            if r is not None:
                all_results.append(r)

    # Country-group correlations: NFA vs income_balance
    group_corrs = []
    if "nfa_gdp_lag" in df.columns:
        both = df.dropna(subset=["nfa_gdp_lag", "income_balance_gdp"]).copy()
        # Rows with lagged NFA exactly zero fall into the 0 ("Debtor") group.
        both["creditor"] = (both["nfa_gdp_lag"] > 0).astype(int)

        for group, label in [(1, "Creditor (NFA>0)"), (0, "Debtor (NFA<0)")]:
            sub = both[both["creditor"] == group]
            # Groups with 30 or fewer rows are not reported.
            if len(sub) > 30:
                c = sub["nfa_gdp_lag"].corr(sub["income_balance_gdp"])
                group_corrs.append({"Group": label, "N": len(sub),
                                    "Countries": sub["iso3"].nunique(),
                                    "Corr(NFA, IB)": c})

    # Write output
    lines = ["# Part E: NFA Creditor/Debtor Detail", ""]
    lines.append("Investigates whether Z captures return-composition effects "
                 "beyond what NFA/GDP measures.")
    lines.append("")

    if all_results:
        rdf = pd.concat(all_results, ignore_index=True)

        lines.append("## Model Comparison")
        lines.append("")
        lines.append("| Model | Dep Var | N | Countries | R² | ρ |")
        lines.append("|-------|---------|---|-----------|----|----|")
        for m in rdf["model"].unique():
            # Model-level statistics repeat on every row; take the first.
            row = rdf[rdf["model"] == m].iloc[0]
            lines.append(f"| {m} | {row['dep_var']} | {row['n_obs']:,} | "
                         f"{row['n_countries']} | {row['r_squared']:.3f} | "
                         f"{row['rho']:.3f} |")

        lines.append("")
        lines.append("## Key Coefficients")
        lines.append("")
        key_vars = ["Z_1", "Z_2", "Z_3", "nfa_gdp_lag",
                     "nfa_positive", "nfa_negative"]
        lines.append("| Variable | Model | Coef | SE | p-value |")
        lines.append("|----------|-------|------|----|---------|")
        for v in key_vars:
            sub = rdf[rdf["variable"] == v]
            for _, row in sub.iterrows():
                if pd.isna(row["p_value"]):
                    continue
                sig = stars(row["p_value"])
                lines.append(f"| {v} | {row['model']} | "
                             f"{row['coefficient']:.4f}{sig} | "
                             f"{row['std_error']:.4f} | {row['p_value']:.4f} |")

    if group_corrs:
        lines.append("")
        lines.append("## NFA-Income Balance Correlations by Group")
        lines.append("")
        lines.append("| Group | N | Countries | Corr(NFA, IB) |")
        lines.append("|-------|---|-----------|---------------|")
        for g in group_corrs:
            lines.append(f"| {g['Group']} | {g['N']:,} | {g['Countries']} | "
                         f"{g['Corr(NFA, IB)']:.3f} |")

    lines.append("")
    lines.append("## Interpretation")
    lines.append("")
    lines.append("NFA/GDP captures the stock of accumulated positions but not "
                 "the composition or return differentials across asset classes. "
                 "Demographics may predict the income balance through both the "
                 "accumulation channel (which NFA captures) and the "
                 "composition/return channel (which NFA does not). If Z remains "
                 "significant conditional on NFA, this is consistent with "
                 "demographics capturing return-composition effects beyond "
                 "what NFA/GDP measures.")

    md_path = OUT_TABLES / "phase5_nfa_detail.md"
    # The report contains non-ASCII characters (→, ², ρ), so the encoding is
    # set explicitly instead of relying on the platform default.
    md_path.write_text("\n".join(lines), encoding="utf-8")
    print(f"\n  Output: {md_path}")


# ═══════════════════════════════════════════════════════════════════════════
# Main
# ═══════════════════════════════════════════════════════════════════════════

def main():
    """Load the panel, run Parts A-E in order, and print a closing summary.

    The summary repeats Part A's validation correlation and lists every
    phase5_*.md file found in OUT_TABLES.
    """
    rule = "=" * 70

    print(rule)
    print("PHASE 5: Reviewer #2 Response — Robustness & Validation")
    print(rule)

    df = load_panel()
    country_count = df['iso3'].nunique()
    first_year, last_year = df['year'].min(), df['year'].max()
    print(f"  Panel: {country_count} countries, {len(df):,} obs, "
          f"{first_year}-{last_year}")

    # Only Part A returns values; the other parts just write their tables.
    corr, n_val, n_countries = part_a_income_validation(df)
    for part in (part_b_fh_long_diff, part_c_fh_income_robustness,
                 part_d_kaopen_robustness, part_e_nfa_detail):
        part(df)

    print("\n" + rule)
    print("PHASE 5 COMPLETE")
    print(rule)
    print(f"  Income validation correlation: {corr:.3f} (N={n_val:,}, "
          f"{n_countries} countries)")
    print("  Output files:")
    for table_path in sorted(OUT_TABLES.glob("phase5_*.md")):
        print(f"    {table_path.name}")


# Run the full Phase 5 pipeline only when this file is executed as a script.
if __name__ == "__main__":
    main()
